package com.alibaba.jstorm.message.netty;

import java.util.ArrayList;

import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBufferOutputStream;
import org.jboss.netty.buffer.ChannelBuffers;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import backtype.storm.messaging.TaskMessage;

/**
 * Accumulates {@link TaskMessage} and {@link ControlMessage} objects and
 * encodes them into a single {@link ChannelBuffer}.
 * 
 * The batch keeps a running total of the number of bytes its encoded form
 * will need. That total always includes the trailing
 * {@code ControlMessage.EOB_MESSAGE} marker which {@link #buffer()} appends
 * after the last message.
 * 
 * This class performs no synchronization of its own.
 */
class MessageBatch {
	private static final Logger LOG = LoggerFactory
			.getLogger(MessageBatch.class);

	/**
	 * Bytes written in front of every TaskMessage payload: a short task id
	 * (2 bytes) followed by an int payload length (4 bytes).
	 */
	private static final int TASK_MESSAGE_HEADER_LENGTH = 6;

	/** Size limit, in encoded bytes, enforced by {@link #tryAdd}. */
	private final int buffer_size;
	/** Messages in insertion order; holds TaskMessage and ControlMessage. */
	private final ArrayList<Object> msgs;
	/** Encoded size of all held messages plus the end-of-batch marker. */
	private int encoded_length;

	/**
	 * Create an empty batch.
	 * 
	 * @param buffer_size
	 *            size limit in encoded bytes, checked by {@link #tryAdd} and
	 *            {@link #isFull()}
	 */
	MessageBatch(int buffer_size) {
		this.buffer_size = buffer_size;
		msgs = new ArrayList<Object>();
		// Reserve room up front for the end-of-batch marker that buffer()
		// always appends.
		encoded_length = ControlMessage.EOB_MESSAGE.encodeLength();
	}

	/**
	 * Append a message to the batch without checking the size limit.
	 * 
	 * @param obj
	 *            a TaskMessage or a ControlMessage
	 * @throws RuntimeException
	 *             if obj is null or of any other type
	 */
	void add(Object obj) {
		if (obj == null) {
			throw new RuntimeException("null object forbidded in message batch");
		}

		if (obj instanceof TaskMessage) {
			TaskMessage msg = (TaskMessage) obj;
			msgs.add(msg);
			encoded_length += msgEncodeLength(msg);
			return;
		}

		if (obj instanceof ControlMessage) {
			ControlMessage msg = (ControlMessage) obj;
			msgs.add(msg);
			encoded_length += msg.encodeLength();
			return;
		}

		throw new RuntimeException("Unsuppoted object type "
				+ obj.getClass().getName());
	}

	/**
	 * Remove the first occurrence of a message from the batch.
	 * 
	 * The encoded length is reduced only when the message was actually held
	 * by this batch; otherwise the running total would no longer match the
	 * bytes written by {@link #buffer()}. Null and unsupported types are
	 * ignored.
	 * 
	 * @param obj
	 *            the message to remove
	 */
	void remove(Object obj) {
		if (obj == null) {
			return;
		}

		if (obj instanceof TaskMessage) {
			TaskMessage msg = (TaskMessage) obj;
			if (msgs.remove(msg)) {
				encoded_length -= msgEncodeLength(msg);
			}
			return;
		}

		if (obj instanceof ControlMessage) {
			ControlMessage msg = (ControlMessage) obj;
			if (msgs.remove(msg)) {
				encoded_length -= msg.encodeLength();
			}
			return;
		}
	}

	/**
	 * @param index
	 *            position in insertion order
	 * @return the message stored at that position
	 */
	Object get(int index) {
		return msgs.get(index);
	}

	/**
	 * Try to add a TaskMessage to the batch.
	 * 
	 * @param taskMsg
	 *            the message to add
	 * @return false if the message could not be added due to the buffer size
	 *         limit; true otherwise
	 * @throws RuntimeException
	 *             if taskMsg is null (raised by {@link #add})
	 */
	boolean tryAdd(TaskMessage taskMsg) {
		// Sum in long so that a very large payload cannot wrap the int
		// total around and slip under the limit.
		long required = (long) encoded_length + msgEncodeLength(taskMsg);
		if (required > buffer_size) {
			return false;
		}
		add(taskMsg);
		return true;
	}

	/**
	 * Number of bytes writeTaskMessage emits for the given message.
	 * 
	 * @param taskMsg
	 *            the message to measure, may be null
	 * @return 0 for null, otherwise the header length plus the payload length
	 */
	private int msgEncodeLength(TaskMessage taskMsg) {
		if (taskMsg == null) {
			return 0;
		}

		byte[] payload = taskMsg.message();
		if (payload == null) {
			return TASK_MESSAGE_HEADER_LENGTH;
		}
		return TASK_MESSAGE_HEADER_LENGTH + payload.length;
	}

	/**
	 * Has this batch used up the allowed buffer size?
	 * 
	 * @return true once the encoded length has reached or passed the limit
	 */
	boolean isFull() {
		return encoded_length >= buffer_size;
	}

	/**
	 * @return true if this batch doesn't hold any messages
	 */
	boolean isEmpty() {
		return msgs.isEmpty();
	}

	/**
	 * @return number of messages in this batch
	 */
	int size() {
		return msgs.size();
	}

	/**
	 * @return encoded size in bytes of the held messages plus the
	 *         end-of-batch marker
	 */
	public int getEncoded_length() {
		return encoded_length;
	}

	/**
	 * Create a buffer containing the encoding of this batch: every message in
	 * insertion order followed by {@code ControlMessage.EOB_MESSAGE}.
	 * 
	 * @return a direct buffer allocated with the batch's encoded length
	 * @throws Exception
	 *             if writing any message fails
	 */
	ChannelBuffer buffer() throws Exception {
		ChannelBufferOutputStream bout = new ChannelBufferOutputStream(
				ChannelBuffers.directBuffer(encoded_length));

		try {
			for (Object msg : msgs) {
				if (msg instanceof TaskMessage) {
					writeTaskMessage(bout, (TaskMessage) msg);
				} else {
					// add() only admits TaskMessage and ControlMessage
					((ControlMessage) msg).write(bout);
				}
			}

			// add a END_OF_BATCH indicator
			ControlMessage.EOB_MESSAGE.write(bout);
		} finally {
			bout.close();
		}

		return bout.buffer();
	}

	/**
	 * Write a TaskMessage into a stream.
	 * 
	 * Each TaskMessage is encoded as: task id as short (2 bytes), payload
	 * length as int (4 bytes), then the payload bytes (omitted when the
	 * payload is null or empty).
	 * 
	 * @param bout
	 *            destination stream
	 * @param message
	 *            the message to encode
	 * @throws RuntimeException
	 *             if the task id does not fit in a short
	 */
	private void writeTaskMessage(ChannelBufferOutputStream bout,
			TaskMessage message) throws Exception {
		byte[] payload = message.message();
		int payload_len = (payload == null) ? 0 : payload.length;

		int task_id = message.task();
		if (task_id > Short.MAX_VALUE) {
			throw new RuntimeException("Task ID should not exceed "
					+ Short.MAX_VALUE);
		}
		// The cast below would silently wrap an id under Short.MIN_VALUE
		// into a different task id, so reject it as well.
		if (task_id < Short.MIN_VALUE) {
			throw new RuntimeException("Task ID should not be less than "
					+ Short.MIN_VALUE);
		}

		bout.writeShort((short) task_id);
		bout.writeInt(payload_len);
		if (payload_len > 0) {
			bout.write(payload);
		}
	}
}